前言

这是屯了一道很久的题目,昨天做完今天补一下题解,其实也并不是很难,只要想清楚一些细节就好了

题面

树上区间覆盖,区间查询颜色段数

解题思路

由易到难是我们数学老师经常说的一句话,所以我们这里也来乱搞应用一下

0、预备

首先,这当然是一道树剖的题,然后加一个线段树维护区间颜色段数,考虑到合并时候最多就是中间会少去一段

$1\ 3\ 4$ 和 $4\ 2 \ 1$合并时,中间的$4$会多算一遍,所以要减去

也就是说,我们需要记录三个值

$l$和$r$是当前区间的左右端点,用$k$表示当前区间的编号

$f[k]\ =\ l$到$r$这段区间中颜色段数

$fl[k]\ =\ $最左边节点的颜色,也就是$l$的颜色

$fr[k]\ =\ $最右边节点的颜色,即$r$的颜色

于是$built$就可以这样写:

inline void Upd(register int k,register int cur) {
    f[k]=f[cur]+f[cur|1]-(fr[cur]==fl[cur|1]);
    fl[k]=fl[cur],fr[k]=fr[cur|1];
}
void built(register int k,register int l,register int r) {
    if(l>r) return;
    if(l==r) {fl[k]=fr[k]=a[rev[l]];f[k]=1;return;}//不要忘记rev数组
    register int mid=(l+r)>>1,cur=k<<1;
    built(cur,l,mid);
    built(cur|1,mid+1,r);
    Upd(k,cur);
}

1、对于不在树上的修改&查询

修改区间覆盖,肯定是直接覆盖不用说的,区间修改注意打标记的方式,然后再和$built$的时候一样$Upd$一下就好了,感性理解这样是对的

对于查询,线段树的查询方式是将$[L,R]$这个区间分成$Log$块,然后依次查询$[L,R]$的每一块,很明显,查询的顺序是从左到右的,那么和$built$的时候一样,$Upd$把左右节点相同颜色的情况减掉即可

$Las\ =\ $上一次查询的块 最右节点的颜色($Las$的初始值为一个不可能在序列中出现的数,比如$0$)

然后就可以这样写查询:

void Query(register int k,register int l,register int r) {
    if(tag[k]) push(k,l,r);
    if(r<L||R<l) return;
    if(L<=l&&r<=R) {
        if(Las==fl[k]) --ans;
        ans+=f[k];Las=fr[k];
        if(l==L) Lres=fl[k];
        if(r==R) Rres=fr[k];
        return;
    }
    register int mid=(l+r)>>1,cur=k<<1;
    Query(cur,l,mid);
    Query(cur|1,mid+1,r);
}

2、对于在树上的修改

虽然会分成很多链,但实际上人的本质是不变的,只要对于每条重链$Modify$一下就好了

3、对于在树上的查询

现在不仅是块了,还有很多条链相连,所以还要考虑链的影响,但其实链的本质和上文所述块是一样的,也只要记录一下相邻的两个点就好了,为了方便,我直接全部记录在数组里,最后再一起遍历了,不要这种写法也是可以的,开几个变量然后在过程中判断就好了

附上超多的代码:

#define fx top[x]
#define fy top[y]
inline void Uni() {
    register int i;
    for(i=2;i<p[0][0];i+=2)
        ans-=(p[i][0]==p[i+1][0]);
    for(i=2;i<p[0][1];i+=2)
        ans-=(p[i][1]==p[i+1][1]);
    if(p[p[0][0]][0]==p[p[0][1]][1]) --ans;
    printf("%d\n",ans);
}
inline void Ask() {
    p[0][0]=p[0][1]=ans=0;//表示的是公共祖先左边这条链和右边这条链的键值个数
    register int l=0;//表示当前的x是在公共祖先的哪一边,初始的时候默认x一侧为0
    while(fx!=fy) {
        if(deep[fx]<deep[fy]) swap(x,y),l^=1;
        L=seg[fx],R=seg[x],Las=0;
        Query(1,1,num);
        p[++p[0][l]][l]=Rres,p[++p[0][l]][l]=Lres;//就是把键值记录进去
        x=fa[fx];
    }
    if(deep[x]<deep[y]) swap(x,y),l^=1;
    L=seg[y],R=seg[x],Las=0;
    Query(1,1,num);
    p[++p[0][l]][l]=Rres,p[++p[0][l]][l]=Lres;
    Uni();//最后再一起搞
    return;
}

Code

#include<bits/stdc++.h>
#define getchar() *(pos++)
#define fx top[x]
#define fy top[y]
#define N 100010
using namespace std;
struct node{
    int to,nxt;
}b[N<<1];
int head[N],seg[N],top[N],rev[N],tag[N<<2],f[N<<2];
int deep[N],Sz[N],son[N],fa[N],fl[N<<2],fr[N<<2];
int n,T,C,t,num,x,y,a[N],p[N<<1][2],L,R,Lres,Rres,ans,Las;
char bf[1<<25],*pos;
inline int read() {
    register int s=0;
    register char c=getchar();
    while(!isdigit(c)) c=getchar();
    while(isdigit(c)) s=(s<<1)+(s<<3)+c-'0',c=getchar();
    return s;
}
inline void add(register int x,register int y) {
    b[++t].to=y,b[t].nxt=head[x],head[x]=t;
    b[++t].to=x,b[t].nxt=head[y],head[y]=t;
}
void dfs1(int k) {
    int i,to;
    deep[k]=deep[fa[k]]+1;Sz[k]=1;
    for(i=head[k];i;i=b[i].nxt)
    {
        to=b[i].to;
        if(to==fa[k]) continue;
        fa[to]=k,dfs1(to);
        Sz[k]+=Sz[to];
        if(Sz[to]>Sz[son[k]])
            son[k]=to;
    }
}
void dfs2(int k) {
    if(son[k]) {
        seg[son[k]]=++num;
        top[son[k]]=top[k];
        rev[num]=son[k];
        dfs2(son[k]);
    }
    int i,to;
    for(i=head[k];i;i=b[i].nxt) {
        to=b[i].to;
        if(top[to]) continue;
        seg[to]=++num;
        rev[num]=to;
        top[to]=to;
        dfs2(to);
    }
}
inline void Upd(register int k,register int cur) {
    f[k]=f[cur]+f[cur|1]-(fr[cur]==fl[cur|1]);
    fl[k]=fl[cur],fr[k]=fr[cur|1];
}
void built(register int k,register int l,register int r) {
    if(l>r) return;
    if(l==r) {fl[k]=fr[k]=a[rev[l]];f[k]=1;return;}
    register int mid=(l+r)>>1,cur=k<<1;
    built(cur,l,mid);
    built(cur|1,mid+1,r);
    Upd(k,cur);
}
inline void push(register int k,register int l,register int r) {
    f[k]=1;
    fl[k]=fr[k]=tag[k];
    if(l!=r) {
        register int cur=k<<1;
        tag[cur]=tag[cur|1]=tag[k];
    }
    tag[k]=0;
}
void Modify(register int k,register int l,register int r) {
    if(tag[k]) push(k,l,r);
    if(r<L||R<l) return;
    if(L<=l&&r<=R) {
        tag[k]=C;
        push(k,l,r);
        return;
    }
    register int mid=(l+r)>>1,cur=k<<1;
    Modify(cur,l,mid);
    Modify(cur|1,mid+1,r);
    Upd(k,cur);
}
void Query(register int k,register int l,register int r) {
    if(tag[k]) push(k,l,r);
    if(r<L||R<l) return;
    if(L<=l&&r<=R) {
        if(Las==fl[k]) --ans;
        ans+=f[k];Las=fr[k];
        if(l==L) Lres=fl[k];
        if(r==R) Rres=fr[k];
        return;
    }
    register int mid=(l+r)>>1,cur=k<<1;
    Query(cur,l,mid);
    Query(cur|1,mid+1,r);
}
inline void Add() {//修改就放一起了,直接模板式改就好了
    C=read();
    while(fx!=fy) {
        if(deep[fx]<deep[fy]) swap(x,y);
        L=seg[fx],R=seg[x];
        Modify(1,1,num);
        x=fa[fx];
    }
    if(deep[x]<deep[y]) swap(x,y);
    L=seg[y],R=seg[x];
    Modify(1,1,num);
    return;
}
inline void Uni() {
    register int i;
    for(i=2;i<p[0][0];i+=2)//首先要理解每条重链的键值成对出现,而两条链之间的键值才是有用的
        ans-=(p[i][0]==p[i+1][0]);
    for(i=2;i<p[0][1];i+=2)
        ans-=(p[i][1]==p[i+1][1]);
    if(p[p[0][0]][0]==p[p[0][1]][1]) --ans;//最后尾连接的时候也要判断一下
    printf("%d\n",ans);
}
inline void Ask() {
    p[0][0]=p[0][1]=ans=0;
    register int l=0;
    while(fx!=fy) {
        if(deep[fx]<deep[fy]) swap(x,y),l^=1;
        L=seg[fx],R=seg[x],Las=0;
        Query(1,1,num);
        p[++p[0][l]][l]=Rres,p[++p[0][l]][l]=Lres;
        x=fa[fx];
    }
    if(deep[x]<deep[y]) swap(x,y),l^=1;
    L=seg[y],R=seg[x],Las=0;
    Query(1,1,num);
    p[++p[0][l]][l]=Rres,p[++p[0][l]][l]=Lres;
    Uni();
    return;
}
int main()
{
    int i;
    char c;
    bf[fread(bf,1,1<<25,stdin)]='\0',pos=bf;
    n=read();T=read();
    for(i=1;i<=n;i++) a[i]=read();
    for(i=1;i<n;i++) x=read(),y=read(),add(x,y);
    num=seg[1]=rev[1]=top[1]=1;
    dfs1(1),dfs2(1),built(1,1,num);
    while(T--)
    {
        c=getchar();
        while(c!='C'&&c!='Q') c=getchar();
        x=read(),y=read();
        if(c=='C') Add();
        else Ask();
    }
    return 0;
}

devil.